/*
 * SPDX-License-Identifier: Apache-2.0
 *
 * The OpenSearch Contributors require contributions made to
 * this file be licensed under the Apache-2.0 license or a
 * compatible open source license.
 */

package org.opensearch.remotestore;

import com.carrotsearch.randomizedtesting.RandomizedTest;

import org.opensearch.action.admin.indices.close.CloseIndexResponse;
import org.opensearch.action.index.IndexResponse;
import org.opensearch.cluster.ClusterState;
import org.opensearch.cluster.metadata.IndexMetadata;
import org.opensearch.cluster.node.DiscoveryNode;
import org.opensearch.cluster.routing.IndexShardRoutingTable;
import org.opensearch.cluster.routing.ShardRouting;
import org.opensearch.common.settings.Settings;
import org.opensearch.core.rest.RestStatus;
import org.opensearch.test.BackgroundIndexer;
import org.opensearch.test.InternalTestCluster;
import org.opensearch.test.OpenSearchIntegTestCase;
import org.junit.Before;

import java.util.Locale;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicReference;

import static org.opensearch.test.hamcrest.OpenSearchAssertions.assertAcked;
import static org.opensearch.test.hamcrest.OpenSearchAssertions.assertHitCount;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.is;

/**
 * Integration tests checking that replica shards take over as primaries when the node holding a primary is stopped, and that
 * documents indexed before (and, in {@link #testFailoverWhileIndexing()}, during) the failover remain searchable afterwards.
 * <p>
 * The cluster is rebuilt for every test ({@code Scope.TEST}, no default data nodes): {@link #setup()} starts a cluster-manager-only
 * node and each test starts the additional nodes it needs. Base index settings come from {@code super.indexSettings()} of
 * {@link RemoteStoreBaseIntegTestCase}.
 */
@OpenSearchIntegTestCase.ClusterScope(scope = OpenSearchIntegTestCase.Scope.TEST, numDataNodes = 0)
public class ReplicaToPrimaryPromotionIT extends RemoteStoreBaseIntegTestCase {

    /** Upper bound on how long the test thread waits for background indexing to reach the point where a node is stopped. */
    private static final long FAILOVER_POINT_TIMEOUT_SECONDS = 60;

    /** Number of primary shards applied by {@link #indexSettings()}; tests randomize it before creating their index. */
    private int shardCount = 5;

    @Before
    public void setup() {
        internalCluster().startClusterManagerOnlyNode();
    }

    @Override
    public Settings indexSettings() {
        return Settings.builder().put(super.indexSettings()).put(IndexMetadata.SETTING_NUMBER_OF_SHARDS, shardCount).build();
    }

    /**
     * Indexes a random number of documents in batches, optionally closes the index, stops the node hosting a random primary, then
     * verifies that every remaining active shard copy is a primary and that all indexed documents are still searchable.
     */
    public void testPromoteReplicaToPrimary() throws Exception {
        internalCluster().startNode();
        internalCluster().startNode();
        final String indexName = randomAlphaOfLength(5).toLowerCase(Locale.ROOT);
        shardCount = scaledRandomIntBetween(1, 5);
        createIndex(indexName);
        int numOfDocs = 0;
        int numIter = scaledRandomIntBetween(0, 10);
        for (int i = 0; i < numIter; i++) {
            final int numOfDoc = scaledRandomIntBetween(0, 200);
            logger.info("indexing {} docs in iteration {}", numOfDoc, i);
            if (numOfDoc > 0) {
                try (
                    BackgroundIndexer indexer = new BackgroundIndexer(
                        indexName,
                        "_doc",
                        client(),
                        numOfDoc,
                        RandomizedTest.scaledRandomIntBetween(2, 5),
                        false,
                        null
                    )
                ) {
                    indexer.setUseAutoGeneratedIDs(true);
                    indexer.start(numOfDoc);
                    waitForIndexed(numOfDoc, indexer);
                    numOfDocs += numOfDoc;
                    indexer.stopAndAwaitStopped();
                    if (random().nextBoolean()) {
                        // When we do make the batch durable/visible: refresh 9 times out of 10, flush otherwise.
                        if (random().nextInt(10) != 0) {
                            refresh(indexName);
                        } else {
                            flush(indexName);
                        }
                    }
                }
            }
        }

        ensureGreen(indexName);

        // sometimes test with a closed index
        final IndexMetadata.State indexState = randomFrom(IndexMetadata.State.OPEN, IndexMetadata.State.CLOSE);
        if (indexState == IndexMetadata.State.CLOSE) {
            CloseIndexResponse closeIndexResponse = client().admin().indices().prepareClose(indexName).get();
            assertThat("close index not acked - " + closeIndexResponse, closeIndexResponse.isAcknowledged(), equalTo(true));
            ensureGreen(indexName);
        }

        // stop the node holding a random primary; the remaining copies are expected to be promoted to primaries
        stopNodeHostingRandomPrimary(indexName);
        ensureYellowAndNoInitializingShards(indexName);

        final ClusterState state = client(internalCluster().getClusterManagerName()).admin().cluster().prepareState().get().getState();
        for (IndexShardRoutingTable shardRoutingTable : state.routingTable().index(indexName)) {
            for (ShardRouting shardRouting : shardRoutingTable.activeShards()) {
                assertThat(shardRouting + " should be promoted as a primary", shardRouting.primary(), is(true));
            }
        }

        if (indexState == IndexMetadata.State.CLOSE) {
            assertAcked(client().admin().indices().prepareOpen(indexName));
            ensureYellowAndNoInitializingShards(indexName);
        }
        refresh(indexName);
        assertHitCount(client().prepareSearch(indexName).setSize(0).get(), numOfDocs);
    }

    /**
     * Indexes documents from a background thread and, once {@code docCount / 2} of them have been acknowledged, stops the node
     * hosting a random primary while indexing continues. Afterwards verifies that exactly the acknowledged documents are searchable.
     * Any failure in the background thread fails the test rather than hanging it or being silently dropped.
     */
    public void testFailoverWhileIndexing() throws Exception {
        internalCluster().startNode();
        internalCluster().startNode();
        final String indexName = randomAlphaOfLength(5).toLowerCase(Locale.ROOT);
        shardCount = scaledRandomIntBetween(1, 5);
        createIndex(indexName);
        ensureGreen(indexName);
        final int docCount = scaledRandomIntBetween(20, 50);
        final int indexDocAfterFailover = scaledRandomIntBetween(20, 50);
        final AtomicInteger numAutoGenDocs = new AtomicInteger();
        final CountDownLatch latch = new CountDownLatch(1);
        final AtomicBoolean finished = new AtomicBoolean(false);
        // Records a failure of the indexing thread so the test thread can rethrow it instead of losing it with the thread.
        final AtomicReference<Throwable> indexingFailure = new AtomicReference<>();
        Thread indexingThread = new Thread(() -> {
            try {
                int docsAfterFailover = 0;
                while (finished.get() == false && numAutoGenDocs.get() < docCount) {
                    IndexResponse indexResponse = internalCluster().clusterManagerClient()
                        .prepareIndex(indexName)
                        .setSource("field", numAutoGenDocs.get())
                        .get();

                    // Only acknowledged writes are counted, since only those must survive the failover.
                    if (indexResponse.status() == RestStatus.CREATED || indexResponse.status() == RestStatus.OK) {
                        numAutoGenDocs.incrementAndGet();
                        if (numAutoGenDocs.get() == docCount / 2) {
                            if (random().nextInt(3) == 0) {
                                refresh(indexName);
                            } else if (random().nextInt(2) == 0) {
                                flush(indexName);
                            }
                            // Signals the test thread to stop the node while indexing continues.
                            latch.countDown();
                        } else if (numAutoGenDocs.get() > docCount / 2) {
                            docsAfterFailover++;
                            if (docsAfterFailover == indexDocAfterFailover) {
                                finished.set(true);
                            }
                        }
                    }
                }
                logger.debug("Done indexing");
            } catch (Exception | AssertionError e) {
                indexingFailure.set(e);
            } finally {
                // Never leave the test thread blocked on the latch if indexing died early; an extra countDown is a no-op.
                latch.countDown();
            }
        });
        indexingThread.start();
        if (latch.await(FAILOVER_POINT_TIMEOUT_SECONDS, TimeUnit.SECONDS) == false) {
            finished.set(true);
            indexingThread.join(TimeUnit.SECONDS.toMillis(FAILOVER_POINT_TIMEOUT_SECONDS));
            fail("indexing did not reach the failover point within " + FAILOVER_POINT_TIMEOUT_SECONDS + "s");
        }
        assertNoIndexingFailure(indexingFailure);

        // stop the node holding a random primary; the remaining copies are expected to be promoted to primaries
        stopNodeHostingRandomPrimary(indexName);
        ensureYellowAndNoInitializingShards(indexName);
        indexingThread.join();
        assertNoIndexingFailure(indexingFailure);
        refresh(indexName);
        assertHitCount(
            client(internalCluster().getClusterManagerName()).prepareSearch(indexName).setSize(0).setTrackTotalHits(true).get(),
            numAutoGenDocs.get()
        );
    }

    /**
     * Picks a random shard of {@code indexName} and stops the node that currently hosts its primary copy.
     *
     * @param indexName the index whose shard routing is inspected
     */
    private void stopNodeHostingRandomPrimary(String indexName) throws Exception {
        final ClusterState state = client(internalCluster().getClusterManagerName()).admin().cluster().prepareState().get().getState();
        final int numShards = state.metadata().index(indexName).getNumberOfShards();
        final ShardRouting primaryShard = state.routingTable().index(indexName).shard(randomIntBetween(0, numShards - 1)).primaryShard();
        final DiscoveryNode nodeToStop = state.nodes().resolveNode(primaryShard.currentNodeId());
        internalCluster().stopRandomNode(InternalTestCluster.nameFilter(nodeToStop.getName()));
    }

    /**
     * Fails the test, chaining the original cause, if the background indexing thread recorded a failure.
     *
     * @param indexingFailure holder populated by the indexing thread when it terminates abnormally
     */
    private static void assertNoIndexingFailure(AtomicReference<Throwable> indexingFailure) {
        final Throwable failure = indexingFailure.get();
        if (failure != null) {
            throw new AssertionError("background indexing failed", failure);
        }
    }
}
